// Copyright 2015-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
//	http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package entity

import (
	"testing"

	"github.com/aws/amazon-ecs-cli/ecs-cli/modules/cli/compose/context"
	"github.com/aws/amazon-ecs-cli/ecs-cli/modules/clients/aws/ec2/mock"
	ecsClient "github.com/aws/amazon-ecs-cli/ecs-cli/modules/clients/aws/ecs"
	"github.com/aws/amazon-ecs-cli/ecs-cli/modules/clients/aws/ecs/mock"
	"github.com/aws/amazon-ecs-cli/ecs-cli/modules/config"
	composeutils "github.com/aws/amazon-ecs-cli/ecs-cli/modules/utils/compose"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/ecs"
	"github.com/docker/libcompose/project"
	"github.com/golang/mock/gomock"
	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
)

type (
	// validateListTasksInput is invoked by TestInfo with every ListTasksInput
	// the entity under test sends, together with the project name and the
	// running test, so each entity kind can check its own request fields.
	validateListTasksInput func(input *ecs.ListTasksInput, projectName string, t *testing.T)

	// setupEntityForTestInfo builds the ProjectEntity under test from the
	// mocked ECS context prepared by TestInfo.
	setupEntityForTestInfo func(ecsContext *context.ECSContext) ProjectEntity
)

// TestInfo is a shared helper that exercises ProjectEntity.Info (used by the
// ps commands) against mocked ECS and EC2 clients.
//
// setupEntity builds the entity under test from the mocked ECSContext, and
// validateFunc is called with every ListTasksInput the entity sends so the
// caller can check entity-specific request fields. filterLocal is passed
// straight through to Info. desiredStatus selects which task listings are
// expected: ecs.DesiredStatusRunning, ecs.DesiredStatusStopped, or "" for
// both; any other value fails the test immediately.
func TestInfo(setupEntity setupEntityForTestInfo, validateFunc validateListTasksInput, t *testing.T, filterLocal bool, desiredStatus string) {
	listRunning := desiredStatus == "" || desiredStatus == ecs.DesiredStatusRunning
	listStopped := desiredStatus == "" || desiredStatus == ecs.DesiredStatusStopped
	if !listRunning && !listStopped {
		// Only these statuses have mocked task listings below; fail fast
		// instead of surfacing a confusing gomock mismatch later.
		t.Fatalf("TestInfo: unsupported desiredStatus %q", desiredStatus)
	}

	projectName := "project"
	containerInstance := "containerInstance"
	ec2InstanceID := "ec2InstanceID"
	ec2Instance := &ec2.Instance{
		PublicIpAddress: aws.String("publicIpAddress"),
	}
	taskDefArn := "arn:123456:taskdefinition/mytaskdef"
	taskDef := &ecs.TaskDefinition{
		TaskDefinitionArn: aws.String(taskDefArn),
		NetworkMode:       aws.String(ecs.NetworkModeHost),
		ContainerDefinitions: []*ecs.ContainerDefinition{
			{
				Name: aws.String("contName"),
				PortMappings: []*ecs.PortMapping{
					{
						ContainerPort: aws.Int64(80),
						HostPort:      aws.Int64(80),
						Protocol:      aws.String("tcp"),
					},
				},
			},
		},
	}

	instanceIdsMap := map[string]string{containerInstance: ec2InstanceID}
	ec2InstancesMap := map[string]*ec2.Instance{ec2InstanceID: ec2Instance}

	container := &ecs.Container{
		Name:         aws.String("contName"),
		ContainerArn: aws.String("contArn/contId"),
		LastStatus:   aws.String("lastStatus"),
	}

	ecsTask := &ecs.Task{
		TaskDefinitionArn:    aws.String(taskDefArn),
		TaskArn:              aws.String("taskArn/taskId"),
		Containers:           []*ecs.Container{container},
		ContainerInstanceArn: aws.String(containerInstance),
		Group:                aws.String(composeutils.GetTaskGroup("", projectName)),
	}

	// Deprecated form: this task is tied to the project through StartedBy
	// instead of Group.
	ecsTaskWithStartedBy := &ecs.Task{
		TaskDefinitionArn:    aws.String(taskDefArn),
		TaskArn:              aws.String("taskArn/taskId"),
		Containers:           []*ecs.Container{container},
		ContainerInstanceArn: aws.String(containerInstance),
		StartedBy:            aws.String(projectName),
	}

	runningTasks := []*ecs.Task{ecsTask}
	stoppedTasks := []*ecs.Task{ecsTask, ecsTaskWithStartedBy}

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	mockEcs := mock_ecs.NewMockECSClient(ctrl)
	mockEc2 := mock_ec2.NewMockEC2Client(ctrl)

	// expectListTasks registers one GetTasksPages call. The call validates the
	// request and checks that it asks for the given desired status. It then
	// hands tasks to the paging callback the entity supplied.
	expectListTasks := func(status, logMsg string, tasks []*ecs.Task) *gomock.Call {
		return mockEcs.EXPECT().GetTasksPages(gomock.Any(), gomock.Any()).Do(func(x, y interface{}) {
			logrus.Info(logMsg)
			// verify input fields
			req := x.(*ecs.ListTasksInput)
			validateFunc(req, projectName, t)
			assert.Equal(t, status, aws.StringValue(req.DesiredStatus), "Expected DesiredStatus to be "+status)

			// execute the function passed as input
			processTasks := y.(ecsClient.ProcessTasksAction)
			processTasks(tasks)
		}).Return(nil)
	}

	var expectedCalls []*gomock.Call

	logrus.Info("desiredStatus in TestInfo: " + desiredStatus)

	if listRunning {
		expectedCalls = append(expectedCalls, expectListTasks(ecs.DesiredStatusRunning, "Running tasks call", runningTasks))
	}
	if listStopped {
		expectedCalls = append(expectedCalls, expectListTasks(ecs.DesiredStatusStopped, "Stopped tasks call", stoppedTasks))
	}

	logrus.Infof("Len of expected list calls: %d", len(expectedCalls))

	// The remaining lookups are expected after the listings, in this order.
	expectedCalls = append(expectedCalls,
		mockEcs.EXPECT().DescribeTaskDefinition(taskDefArn).Return(taskDef, nil),
		mockEcs.EXPECT().GetEC2InstanceIDs([]*string{&containerInstance}).Return(instanceIdsMap, nil),
		mockEc2.EXPECT().DescribeInstances([]*string{&ec2InstanceID}).Return(ec2InstancesMap, nil),
	)
	gomock.InOrder(expectedCalls...)

	// Named ecsContext so the local does not shadow the imported context package.
	ecsContext := &context.ECSContext{
		ECSClient:     mockEcs,
		EC2Client:     mockEc2,
		CommandConfig: &config.CommandConfig{},
		Context: project.Context{
			ProjectName: projectName,
		},
	}
	entity := setupEntity(ecsContext)
	infoSet, err := entity.Info(filterLocal, desiredStatus)
	assert.NoError(t, err, "Unexpected error when getting info")

	// The assertion is about containers, so sum containers across the listed
	// tasks. Each fixture task has exactly one container, so this equals the
	// task count for the current data.
	countContainers := func(tasks []*ecs.Task) int {
		n := 0
		for _, task := range tasks {
			n += len(task.Containers)
		}
		return n
	}
	expectedCountOfContainers := 0
	if listRunning {
		expectedCountOfContainers += countContainers(runningTasks)
	}
	if listStopped {
		expectedCountOfContainers += countContainers(stoppedTasks)
	}
	assert.Len(t, infoSet, expectedCountOfContainers, "Expected containers count to match")
}
